"""
Phase 9: Formal OOS calibration following Lunsford & West (2024).

Tests:
1. Random walk horse race: does the demographic model beat a random walk
   at 10-year and 25-year horizons on real safe rates?
2. PIT (Probability Integral Transform) calibration: do forecast intervals
   have correct coverage?
3. Expanding-window 1-year-ahead OOS comparison across models
   (demographic model, random walk, unconditional mean).

Motivated by Lunsford & West, "An Empirical Evaluation of Some Long-Horizon
Macroeconomic Forecasts", Cleveland Fed WP 24-20.
"""

import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats
import sys
# Prepend the multilateral/69_country/src directory to the module search path.
# NOTE(review): no import from that directory is visible in this file -- confirm
# the path entry is still needed.
sys.path.insert(0, str(Path("/mnt/c/demographics_capital_flows/multilateral/69_country/src")))

# Root directory of the monetary project; data and outputs are resolved under it.
PROJECT = Path("/mnt/c/demographics_capital_flows/monetary")
# Directory that save_results() writes the markdown table into.
TABLE_DIR = PROJECT / "output" / "tables"

# Read once at import time; prepare_rate_panel() filters this frame.
panel = pd.read_csv(PROJECT / "data" / "processed" / "monetary_panel.csv")


def prepare_rate_panel():
    """Build the complete-case real-rate panel used by all tests.

    Keeps rows of the module-level ``panel`` that have a ``real_bond_10y``
    value and ``year <= 2024``, restricts them to the identifier, rate,
    demographic (``Z_1``-``Z_3``) and control columns, and drops any row with
    a missing value in those columns. Prints a one-line sample summary.

    Returns:
        A new DataFrame holding only the required columns.
    """
    required = ['iso3', 'year', 'real_bond_10y', 'Z_1', 'Z_2', 'Z_3',
                'rgdp_growth', 'inflation', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']

    has_rate = panel['real_bond_10y'].notna()
    within_sample = panel['year'] <= 2024
    # Column selection and dropna both return new objects, so `panel` is untouched.
    rate_panel = panel.loc[has_rate & within_sample, required].dropna()

    n_obs = len(rate_panel)
    n_countries = rate_panel['iso3'].nunique()
    first_year = rate_panel['year'].min()
    last_year = rate_panel['year'].max()
    print(f"Rate panel: {n_obs} obs, {n_countries} countries, {first_year}-{last_year}")
    return rate_panel


def _rw_fit_origin_year(df, origin_year, min_obs=10):
    """Fit the origin-year cross-section used by ``random_walk_horse_race``.

    Regresses ``real_bond_10y`` on a constant and ``Z_1``-``Z_3`` over all rows
    of ``df`` observed in ``origin_year`` (rows missing the rate or ``Z_1`` are
    dropped first).

    Returns:
        ``(beta, mean_rate)`` -- the least-squares coefficients and the
        cross-sectional mean of ``real_bond_10y`` -- or ``None`` when fewer
        than ``min_obs`` rows are usable or the least-squares solver fails.
    """
    origin_data = df[df['year'] == origin_year].dropna(subset=['real_bond_10y', 'Z_1'])
    if len(origin_data) < min_obs:
        return None

    y = origin_data['real_bond_10y'].values
    X = np.column_stack([
        np.ones(len(origin_data)),
        origin_data['Z_1'].values,
        origin_data['Z_2'].values,
        origin_data['Z_3'].values,
    ])
    try:
        beta = np.linalg.lstsq(X, y, rcond=None)[0]
    except np.linalg.LinAlgError:
        # Only solver failures are treated as "no forecast"; anything else
        # (including KeyboardInterrupt) must propagate.
        return None
    return beta, y.mean()


def random_walk_horse_race(df, horizons=(10, 25)):
    """Compare the demographic model with a random walk at long horizons.

    For every country and every origin year whose ``origin + h`` observation
    exists, three forecasts of the target-year ``real_bond_10y`` are scored by
    squared error:

    * demographic model: coefficients from the origin-year cross-sectional
      regression applied to the country's TARGET-year ``Z_1``-``Z_3``;
    * random walk: the country's origin-year rate;
    * naive: the origin-year cross-sectional mean rate.

    Args:
        df: Panel with columns ``iso3``, ``year``, ``real_bond_10y``,
            ``Z_1``, ``Z_2``, ``Z_3``.
        horizons: Iterable of forecast horizons in years.

    Returns:
        DataFrame with one row per horizon that produced at least one forecast
        pair (columns ``Horizon``, ``N_pairs``, ``RMSE_demo``, ``RMSE_rw``,
        ``RMSE_naive``, ``Rel_RMSE_demo_rw``, ``DM_stat``, ``DM_p``); empty if
        no horizon produced any pair.

    Note:
        The Diebold-Mariano statistic uses the plain sample standard error of
        the loss differential (no autocorrelation correction) and a normal
        reference distribution. It is reported as NaN when the standard error
        is zero or undefined.
    """
    print("\n" + "=" * 70)
    print("TEST 1: Random Walk Horse Race at Long Horizons")
    print("=" * 70)

    results = []
    countries = sorted(df.iso3.unique())
    # The cross-sectional fit depends only on the origin year, so estimate it
    # once per year instead of once per (country, horizon, year).
    fit_cache = {}

    for h in horizons:
        demo_errors = []
        rw_errors = []
        naive_errors = []

        for c in countries:
            cdf = df[df['iso3'] == c].sort_values('year')
            years = cdf['year'].values
            rates = cdf['real_bond_10y'].values
            z_rows = cdf[['Z_1', 'Z_2', 'Z_3']].to_numpy(dtype=float)

            for i, origin_year in enumerate(years):
                target_idx = np.where(years == origin_year + h)[0]
                if len(target_idx) == 0:
                    continue
                target_idx = target_idx[0]

                if origin_year not in fit_cache:
                    fit_cache[origin_year] = _rw_fit_origin_year(df, origin_year)
                fit = fit_cache[origin_year]
                if fit is None:
                    continue
                beta, naive_forecast = fit

                actual = rates[target_idx]
                # Random walk forecast: rate stays at its origin-year level.
                rw_forecast = rates[i]
                # Demographic forecast: origin-year coefficients applied to
                # the country's target-year demographics.
                x_target = np.array([1.0, *z_rows[target_idx]])
                demo_forecast = x_target @ beta

                demo_errors.append((actual - demo_forecast) ** 2)
                rw_errors.append((actual - rw_forecast) ** 2)
                naive_errors.append((actual - naive_forecast) ** 2)

        if not demo_errors:
            print(f"\n  h={h}: No valid forecast pairs")
            continue

        rmse_demo = np.sqrt(np.mean(demo_errors))
        rmse_rw = np.sqrt(np.mean(rw_errors))
        rmse_naive = np.sqrt(np.mean(naive_errors))

        # Relative RMSE
        rel_demo_rw = rmse_demo / rmse_rw
        rel_demo_naive = rmse_demo / rmse_naive

        print(f"\n  Horizon h={h} years ({len(demo_errors)} forecast pairs)")
        print(f"    Demographic model RMSE: {rmse_demo:.3f}")
        print(f"    Random walk RMSE:       {rmse_rw:.3f}")
        print(f"    Sample mean RMSE:       {rmse_naive:.3f}")
        print(f"    Relative RMSE (demo/RW): {rel_demo_rw:.3f} ({'BEATS RW' if rel_demo_rw < 1 else 'RW wins'})")
        print(f"    Relative RMSE (demo/mean): {rel_demo_naive:.3f}")

        # Diebold-Mariano test (demo vs RW) on the squared-error differential.
        d = np.array(demo_errors) - np.array(rw_errors)
        dm_stat = np.nan
        dm_p = np.nan
        if len(d) > 1:
            std_err = np.std(d, ddof=1) / np.sqrt(len(d))
            if std_err > 0:
                dm_stat = np.mean(d) / std_err
                dm_p = 2 * (1 - stats.norm.cdf(abs(dm_stat)))
        print(f"    Diebold-Mariano (demo vs RW): DM={dm_stat:.3f}, p={dm_p:.3f}")

        results.append({
            'Horizon': f'{h}yr',
            'N_pairs': len(demo_errors),
            'RMSE_demo': rmse_demo,
            'RMSE_rw': rmse_rw,
            'RMSE_naive': rmse_naive,
            'Rel_RMSE_demo_rw': rel_demo_rw,
            'DM_stat': dm_stat,
            'DM_p': dm_p,
        })

    return pd.DataFrame(results)


def _pit_fit_origin_year(df, origin_year, min_obs=10):
    """Fit the origin-year cross-section used by ``pit_calibration``.

    Regresses ``real_bond_10y`` on a constant and ``Z_1``-``Z_3`` over all rows
    of ``df`` observed in ``origin_year`` (rows missing the rate or ``Z_1`` are
    dropped first).

    Returns:
        ``(beta, sigma)`` -- the least-squares coefficients and the in-sample
        residual standard deviation (degrees-of-freedom corrected for the
        number of coefficients) -- or ``None`` when fewer than ``min_obs``
        rows are usable, the solver fails, or the residual variance is not a
        positive finite number.
    """
    origin_data = df[df['year'] == origin_year].dropna(subset=['real_bond_10y', 'Z_1'])
    if len(origin_data) < min_obs:
        return None

    y_train = origin_data['real_bond_10y'].values
    X_train = np.column_stack([
        np.ones(len(origin_data)),
        origin_data['Z_1'].values,
        origin_data['Z_2'].values,
        origin_data['Z_3'].values,
    ])
    try:
        beta = np.linalg.lstsq(X_train, y_train, rcond=None)[0]
    except np.linalg.LinAlgError:
        # Only solver failures are skipped; other exceptions must propagate.
        return None

    resid = y_train - X_train @ beta
    sigma2 = np.var(resid, ddof=len(beta))
    # A zero or undefined variance would turn every PIT into NaN and
    # invalidate the KS test downstream.
    if not np.isfinite(sigma2) or sigma2 <= 0:
        return None
    return beta, np.sqrt(sigma2)


def pit_calibration(df, horizons=(10,)):
    """Probability Integral Transform (PIT) calibration test.

    For every country and origin year whose ``origin + h`` observation exists,
    the forecast distribution of the target-year ``real_bond_10y`` is taken to
    be normal, centred on the origin-year cross-sectional regression applied to
    the country's target-year ``Z_1``-``Z_3``, with the in-sample residual
    standard deviation as scale. The PIT is the forecast CDF evaluated at the
    realised rate. PITs are pooled across countries and origins, tested for
    uniformity with a Kolmogorov-Smirnov test, and summarised as empirical
    coverage of central 50/80/90/95% intervals.

    Args:
        df: Panel with columns ``iso3``, ``year``, ``real_bond_10y``,
            ``Z_1``, ``Z_2``, ``Z_3``.
        horizons: Iterable of forecast horizons in years.

    Returns:
        DataFrame with one row per horizon that yielded at least 20 PIT values
        (columns ``Horizon``, ``N``, ``KS_stat``, ``KS_p``, ``Coverage_50``,
        ``Coverage_80``, ``Coverage_90``, ``Coverage_95``); empty otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST 2: PIT Calibration (Lunsford-West Framework)")
    print("=" * 70)

    results = []
    countries = sorted(df.iso3.unique())
    # The fit depends only on the origin year: estimate once per year.
    fit_cache = {}

    for h in horizons:
        pit_values = []

        for c in countries:
            cdf = df[df['iso3'] == c].sort_values('year')
            years = cdf['year'].values
            rates = cdf['real_bond_10y'].values
            z_rows = cdf[['Z_1', 'Z_2', 'Z_3']].to_numpy(dtype=float)

            for origin_year in years:
                target_idx = np.where(years == origin_year + h)[0]
                if len(target_idx) == 0:
                    continue
                target_idx = target_idx[0]

                if origin_year not in fit_cache:
                    fit_cache[origin_year] = _pit_fit_origin_year(df, origin_year)
                fit = fit_cache[origin_year]
                if fit is None:
                    continue
                beta, sigma = fit

                actual = rates[target_idx]
                y_hat = np.array([1.0, *z_rows[target_idx]]) @ beta

                # PIT: CDF of the realised value under N(y_hat, sigma^2).
                pit_values.append(stats.norm.cdf(actual, loc=y_hat, scale=sigma))

        if len(pit_values) < 20:
            print(f"\n  h={h}: Insufficient PIT values ({len(pit_values)})")
            continue

        pit_arr = np.array(pit_values)

        # KS test: PIT should be uniform(0,1) if well-calibrated
        ks_stat, ks_p = stats.kstest(pit_arr, 'uniform')

        # Empirical coverage of central intervals at each nominal level.
        coverage = {}
        for alpha in [0.50, 0.80, 0.90, 0.95]:
            tail = (1 - alpha) / 2
            covered = np.mean((pit_arr > tail) & (pit_arr < 1 - tail))
            coverage[f'{int(alpha*100)}%'] = covered

        print(f"\n  Horizon h={h} years ({len(pit_values)} PIT values)")
        print(f"    KS test (H0: uniform): stat={ks_stat:.3f}, p={ks_p:.3f}")
        print(f"    {'Well-calibrated' if ks_p > 0.05 else 'MISCALIBRATED'} at 5% level")
        print(f"    Coverage rates (nominal → actual):")
        for nom, act in coverage.items():
            direction = "↑" if act > float(nom.strip('%'))/100 else "↓"
            print(f"      {nom} → {act:.1%} {direction}")

        # PIT histogram (text-based); bar length is scaled so that a bin with
        # exactly the expected count gets 20 characters.
        print(f"    PIT histogram (should be flat):")
        bins = np.linspace(0, 1, 11)
        counts, _ = np.histogram(pit_arr, bins=bins)
        expected = len(pit_arr) / 10
        for j in range(10):
            bar = '#' * int(counts[j] / expected * 20)
            print(f"      [{bins[j]:.1f}-{bins[j+1]:.1f}]: {counts[j]:3d} {bar}")

        results.append({
            'Horizon': f'{h}yr',
            'N': len(pit_values),
            'KS_stat': ks_stat,
            'KS_p': ks_p,
            'Coverage_50': coverage['50%'],
            'Coverage_80': coverage['80%'],
            'Coverage_90': coverage['90%'],
            'Coverage_95': coverage['95%'],
        })

    return pd.DataFrame(results)


def expanding_window_oos(df):
    """Expanding-window 1-step-ahead comparison of three forecasts.

    Walking through the sorted distinct years of ``df`` (starting once 15
    earlier years exist), a pooled least-squares regression of
    ``real_bond_10y`` on a constant, ``Z_1``-``Z_3``, GDP growth, inflation and
    fiscal balance is fitted on all rows up to the last training year and
    evaluated on the rows of the next year in the sample. Each evaluated row is
    also scored against a random walk (the same country's rate in the last
    training year) and the training-sample mean rate. Rows without a
    last-training-year observation for their country are skipped.

    Prints RMSEs, a Diebold-Mariano statistic for model vs random walk, and a
    pre-/post-2008 split.

    Returns:
        Dict with ``RMSE_demo``, ``RMSE_rw``, ``RMSE_mean``, ``Rel_RMSE``,
        ``DM_stat``, ``DM_p`` and ``N``; ``None`` if no prediction was made.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TEST 3: Expanding-Window 1-Year-Ahead OOS")
    print(banner)

    min_train_years = 15
    target = 'real_bond_10y'
    regressors = ['Z_1', 'Z_2', 'Z_3', 'rgdp_growth', 'inflation', 'fiscal_bal_gdp']
    calendar = sorted(df.year.unique())

    sq_demo = []
    sq_rw = []
    sq_mean = []
    forecast_years = []

    # Pair each admissible last-training year with the next year in the sample.
    year_pairs = zip(calendar[min_train_years:-1], calendar[min_train_years + 1:])
    for last_train_year, forecast_year in year_pairs:
        in_sample = df[df['year'] <= last_train_year]
        holdout = df[df['year'] == forecast_year]
        if len(in_sample) < 50 or len(holdout) < 5:
            continue

        fit_rows = in_sample[[target] + regressors].dropna()
        eval_rows = holdout[[target, 'iso3'] + regressors].dropna()
        if len(fit_rows) < 30 or len(eval_rows) < 3:
            continue

        y_fit = fit_rows[target].values
        design = np.column_stack(
            [np.ones(len(fit_rows))] + [fit_rows[name].values for name in regressors]
        )
        coefs = np.linalg.lstsq(design, y_fit, rcond=None)[0]
        # Unconditional-mean benchmark is the same for every row of this year.
        pooled_mean = y_fit.mean()

        for _, obs in eval_rows.iterrows():
            # Random walk needs the same country's rate in the last training year.
            lagged = df[(df['iso3'] == obs['iso3']) & (df['year'] == last_train_year)]
            if len(lagged) == 0:
                continue

            realized = obs[target]
            features = np.array([1] + [obs[name] for name in regressors])

            sq_demo.append((realized - features @ coefs) ** 2)
            sq_rw.append((realized - lagged.iloc[0][target]) ** 2)
            sq_mean.append((realized - pooled_mean) ** 2)
            forecast_years.append(forecast_year)

    if not sq_demo:
        print("  No valid predictions")
        return None

    demo_arr = np.array(sq_demo)
    rw_arr = np.array(sq_rw)
    mean_arr = np.array(sq_mean)
    year_arr = np.asarray(forecast_years)

    rmse_demo = np.sqrt(np.mean(demo_arr))
    rmse_rw = np.sqrt(np.mean(rw_arr))
    rmse_mean = np.sqrt(np.mean(mean_arr))
    n_pred = len(sq_demo)

    print(f"\n  Expanding window: {n_pred} predictions "
          f"({min(forecast_years)}-{max(forecast_years)})")
    print(f"    Demographic model RMSE: {rmse_demo:.3f}")
    print(f"    Random walk RMSE:       {rmse_rw:.3f}")
    print(f"    Unconditional mean RMSE: {rmse_mean:.3f}")
    print(f"    Relative RMSE (demo/RW): {rmse_demo/rmse_rw:.3f}")

    # Diebold-Mariano statistic on the squared-error differential.
    loss_diff = demo_arr - rw_arr
    dm_stat = np.mean(loss_diff) / (np.std(loss_diff, ddof=1) / np.sqrt(len(loss_diff)))
    dm_p = 2 * (1 - stats.norm.cdf(abs(dm_stat)))
    print(f"    Diebold-Mariano: DM={dm_stat:.3f}, p={dm_p:.3f}")

    # Split predictions at the 2008 boundary.
    before_gfc = year_arr <= 2008
    after_gfc = year_arr > 2008

    if before_gfc.any():
        print(f"\n    Pre-GFC ({year_arr[before_gfc].min()}-2008):")
        print(f"      Demo RMSE: {np.sqrt(np.mean(demo_arr[before_gfc])):.3f}")
        print(f"      RW RMSE:   {np.sqrt(np.mean(rw_arr[before_gfc])):.3f}")
    if after_gfc.any():
        print(f"    Post-GFC (2009-{year_arr[after_gfc].max()}):")
        print(f"      Demo RMSE: {np.sqrt(np.mean(demo_arr[after_gfc])):.3f}")
        print(f"      RW RMSE:   {np.sqrt(np.mean(rw_arr[after_gfc])):.3f}")

    return {
        'RMSE_demo': rmse_demo, 'RMSE_rw': rmse_rw, 'RMSE_mean': rmse_mean,
        'Rel_RMSE': rmse_demo / rmse_rw, 'DM_stat': dm_stat, 'DM_p': dm_p,
        'N': n_pred,
    }


def save_results(rw_results, pit_results, ew_results):
    """Write the combined calibration results as a markdown table.

    The file ``phase9_table22_oos_calibration.md`` is written under the
    module-level ``TABLE_DIR``. The directory is created if it does not exist,
    and the file is always encoded as UTF-8 because the footnotes contain
    non-ASCII subscript characters.

    Args:
        rw_results: DataFrame with columns ``Horizon``, ``N_pairs``,
            ``RMSE_demo``, ``RMSE_rw``, ``RMSE_naive``, ``Rel_RMSE_demo_rw``,
            ``DM_stat``, ``DM_p`` (Panel A); may be empty.
        pit_results: DataFrame with columns ``Horizon``, ``N``, ``KS_stat``,
            ``KS_p``, ``Coverage_50``, ``Coverage_80``, ``Coverage_90``,
            ``Coverage_95`` (Panel B); may be empty.
        ew_results: Dict with keys ``RMSE_demo``, ``RMSE_rw``, ``RMSE_mean``,
            ``Rel_RMSE``, ``DM_stat``, ``DM_p``, ``N`` (Panel C), or a falsy
            value (e.g. ``None``) to omit Panel C.
    """
    lines = [
        "# Table 22: Out-of-Sample Calibration Tests",
        "",
        "*Following Lunsford & West (2024) framework*",
        "",
        "## Panel A: Random Walk Horse Race (Long Horizons)",
        "",
        "| Horizon | N Pairs | RMSE Demo | RMSE RW | RMSE Mean | Rel RMSE (Demo/RW) | DM stat | DM p |",
        "|---------|---------|-----------|---------|-----------|-------------------|---------|------|",
    ]
    for _, row in rw_results.iterrows():
        lines.append(f"| {row['Horizon']} | {int(row['N_pairs'])} | {row['RMSE_demo']:.3f} | "
                     f"{row['RMSE_rw']:.3f} | {row['RMSE_naive']:.3f} | {row['Rel_RMSE_demo_rw']:.3f} | "
                     f"{row['DM_stat']:.3f} | {row['DM_p']:.3f} |")
    lines.append("")

    lines.extend([
        "## Panel B: PIT Calibration",
        "",
        "| Horizon | N | KS stat | KS p | Cov 50% | Cov 80% | Cov 90% | Cov 95% |",
        "|---------|---|---------|------|---------|---------|---------|---------|",
    ])
    for _, row in pit_results.iterrows():
        lines.append(f"| {row['Horizon']} | {int(row['N'])} | {row['KS_stat']:.3f} | {row['KS_p']:.3f} | "
                     f"{row['Coverage_50']:.1%} | {row['Coverage_80']:.1%} | "
                     f"{row['Coverage_90']:.1%} | {row['Coverage_95']:.1%} |")
    lines.append("")

    # Panel C is optional: the expanding-window test may have produced nothing.
    if ew_results:
        lines.extend([
            "## Panel C: Expanding-Window 1-Year-Ahead",
            "",
            "| Metric | Value |",
            "|--------|-------|",
            f"| RMSE Demographic Model | {ew_results['RMSE_demo']:.3f} |",
            f"| RMSE Random Walk | {ew_results['RMSE_rw']:.3f} |",
            f"| RMSE Unconditional Mean | {ew_results['RMSE_mean']:.3f} |",
            f"| Relative RMSE (Demo/RW) | {ew_results['Rel_RMSE']:.3f} |",
            f"| Diebold-Mariano stat | {ew_results['DM_stat']:.3f} |",
            f"| Diebold-Mariano p | {ew_results['DM_p']:.3f} |",
            f"| N predictions | {ew_results['N']} |",
        ])
    lines.extend([
        "",
        "*Panel A: Cross-sectional demographic model estimated at origin year, applied to target-year demographics.*",
        "*Panel B: PIT values should be uniform(0,1) if forecast distribution is well-calibrated. KS p > 0.05 = well-calibrated.*",
        "*Panel C: Model estimated on all data up to year t, predicts year t+1. Controls: Z₁-Z₃, GDP growth, inflation, fiscal balance.*",
    ])

    # Create the output directory on first use instead of failing on a fresh checkout.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    outpath = TABLE_DIR / "phase9_table22_oos_calibration.md"
    # Explicit UTF-8: the platform default encoding may not be able to encode
    # the subscript characters in the Panel C footnote.
    outpath.write_text('\n'.join(lines), encoding="utf-8")
    print(f"\nSaved: {outpath}")


if __name__ == '__main__':
    # Script entry point: build the panel, run the three tests in order,
    # then write the combined markdown table.
    rate_panel = prepare_rate_panel()
    long_horizons = [10, 25]

    horse_race_table = random_walk_horse_race(rate_panel, horizons=list(long_horizons))
    pit_table = pit_calibration(rate_panel, horizons=list(long_horizons))
    expanding_summary = expanding_window_oos(rate_panel)

    save_results(horse_race_table, pit_table, expanding_summary)
    print("\nDone.")
